{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# -*- coding:utf-8 -*-\n",
    "\n",
    "from __future__ import print_function\n",
    "from __future__ import division\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "from tensorflow.python.layers.core import Dense\n",
    "import sys\n",
    "import os\n",
    "import time\n",
    "from tqdm import tqdm \n",
    "\n",
    "sys.path.append('..')\n",
    "from data_process.sequence_example_helpers import get_padded_batch\n",
    "from data_process.get_ids import get_vocab_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W_encoder_embedding.shape= (87509, 256)\n",
      "W_decoder_embedding.shape= (129939, 256)\n",
      "vocab file ../data/vocab.en, vocab size = 87509\n",
      "vocab file ../data/vocab.zh, vocab size = 129939\n",
      "EN_UNK_ID= 1\n"
     ]
    }
   ],
   "source": [
    "# Pre-trained embedding matrices: English for the encoder, Chinese for the decoder.\n",
    "W_encoder_embedding = np.load('../data/en_embeddings.npy')\n",
    "W_decoder_embedding = np.load('../data/zh_embeddings.npy')\n",
    "print('W_encoder_embedding.shape=', W_encoder_embedding.shape)\n",
    "print('W_decoder_embedding.shape=', W_decoder_embedding.shape)\n",
    "\n",
    "# word <-> id lookup tables for both languages\n",
    "dict_en_word2id, dict_en_id2word = get_vocab_map('../data/vocab.en')\n",
    "dict_zh_word2id, dict_zh_id2word = get_vocab_map('../data/vocab.zh')\n",
    "\n",
    "# ids of the special tokens referenced by the graph and the prediction helpers\n",
    "GO_ID, EOS_ID = dict_zh_word2id['<GO>'], dict_zh_word2id['<EOS>']\n",
    "EN_UNK_ID, EN_PAD_ID = dict_en_word2id['<UNK>'], dict_en_word2id['<PAD>']\n",
    "print('EN_UNK_ID=', EN_UNK_ID)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Hyperparameters and input placeholders\n",
    "\n",
    "# lr = tf.placeholder(tf.float32, name='learning_rate')\n",
    "# batch_size = tf.placeholder(tf.int64, [])\n",
    "num_layers = 2  # number of stacked LSTM cells in the encoder and in the decoder\n",
    "rnn_size = 256  # units per LSTM cell\n",
    "start_lr = 1e-3  # initial learning rate of the decay schedule\n",
    "decay_steps = 10000\n",
    "decay_rate = 0.9\n",
    "global_step = tf.Variable(0, trainable=False, dtype=tf.int64)\n",
    "# Staircase schedule: lr = start_lr * decay_rate ** floor(global_step / decay_steps)\n",
    "lr = tf.train.exponential_decay(start_lr, global_step, decay_steps, decay_rate, staircase=True)\n",
    "batch_size = 32\n",
    "\n",
    "# Inputs\n",
    "with tf.variable_scope('inputs'): \n",
    "    # [batch, time] int32 id matrices; both dimensions are left dynamic\n",
    "    inputs = tf.placeholder(tf.int32, [None, None], name='inputs')\n",
    "    targets = tf.placeholder(tf.int32, [None, None], name='targets')\n",
    "    # Per-example sequence lengths; both are supplied through feed_dict at run time.\n",
    "    target_sequence_length = tf.placeholder(tf.int32, (None,), name='target_sequence_length')\n",
    "    source_sequence_length = tf.placeholder(tf.int32, (None,), name='source_sequence_length')\n",
    "\n",
    "    # Longest target length in the fed batch\n",
    "    max_target_sequence_length = tf.reduce_max(target_sequence_length, name='max_target_len')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"rnn/transpose:0\", shape=(?, ?, 256), dtype=float32)\n",
      "(LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_2:0' shape=(?, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_3:0' shape=(?, 256) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_4:0' shape=(?, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_5:0' shape=(?, 256) dtype=float32>))\n"
     ]
    }
   ],
   "source": [
    "# Embedding tables: initialised from the pre-trained matrices and kept trainable\n",
    "with tf.variable_scope('embeddings'):\n",
    "    encoder_embeddings = tf.get_variable(name='encoder_embedding', shape=W_encoder_embedding.shape,\n",
    "                                             initializer=tf.constant_initializer(W_encoder_embedding), trainable=True)\n",
    "    decoder_embeddings = tf.get_variable(name='decoder_embedding', shape=W_decoder_embedding.shape,\n",
    "                                             initializer=tf.constant_initializer(W_decoder_embedding), trainable=True)\n",
    "\n",
    "# Encoder input: one embedding vector per source token id\n",
    "encoder_embed_input = tf.nn.embedding_lookup(encoder_embeddings, inputs)\n",
    "\n",
    "# RNN cell\n",
    "def get_lstm_cell(rnn_size):\n",
    "    \"\"\"Return one LSTMCell with `rnn_size` units, weights drawn uniformly from [-0.1, 0.1] (seed=2).\"\"\"\n",
    "    uniform_init = tf.random_uniform_initializer(-0.1, 0.1, seed=2)\n",
    "    return tf.contrib.rnn.LSTMCell(rnn_size, initializer=uniform_init)\n",
    "\n",
    "# Encoder: num_layers stacked LSTM cells unrolled over the embedded source tokens\n",
    "encoder_cell = tf.contrib.rnn.MultiRNNCell([get_lstm_cell(rnn_size) for _ in range(num_layers)])\n",
    "\n",
    "# source_sequence_length gives the per-example lengths of the padded batch\n",
    "encoder_output, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_embed_input, \n",
    "                                                  sequence_length=source_sequence_length, dtype=tf.float32)\n",
    "\n",
    "# encoder_output\n",
    "#   If time_major == False (default), this will be a `Tensor` shaped:\n",
    "#     `[batch_size, max_time, cell.output_size]`.\n",
    "#   If time_major == True, this will be a `Tensor` shaped:\n",
    "#     `[max_time, batch_size, cell.output_size]`.\n",
    "print(encoder_output) \n",
    "print(encoder_state)  # encoder_state is the state after the last step (one LSTMStateTuple per layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BasicDecoderOutput(rnn_output=<tf.Tensor 'decode/decoder/transpose:0' shape=(?, ?, 129939) dtype=float32>, sample_id=<tf.Tensor 'decode/decoder/transpose_1:0' shape=(?, ?) dtype=int32>)\n",
      "Tensor(\"decode/decoder/transpose:0\", shape=(?, ?, 129939), dtype=float32)\n",
      "Tensor(\"decode/decoder/transpose_1:0\", shape=(?, ?), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "# Decoder: a teacher-forced training branch and a greedy prediction branch that share weights\n",
    "target_vocab_size = W_decoder_embedding.shape[0]\n",
    "decoder_input = targets[:, :-1]  # decoder input is the target without its last token\n",
    "decoder_embed_input = tf.nn.embedding_lookup(decoder_embeddings, decoder_input)\n",
    "decoder_cell = tf.contrib.rnn.MultiRNNCell([get_lstm_cell(rnn_size) for _ in range(num_layers)])\n",
    "# Projects every decoder output to scores over the target vocabulary\n",
    "output_layer = Dense(target_vocab_size,\n",
    "                         kernel_initializer = tf.truncated_normal_initializer(mean = 0.0, stddev=0.1))\n",
    "\n",
    "# 4. Training decoder\n",
    "with tf.variable_scope(\"decode\"):\n",
    "    # The helper feeds the ground-truth target embeddings as the decoder input at each step\n",
    "    training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_embed_input,\n",
    "                                                        sequence_length=target_sequence_length,\n",
    "                                                        time_major=False)\n",
    "    # Build the decoder; it starts from the encoder's final state\n",
    "    training_decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell,\n",
    "                                                       training_helper,\n",
    "                                                       encoder_state,\n",
    "                                                       output_layer) \n",
    "    training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder,\n",
    "                                                                   impute_finished=True,\n",
    "                                                                   maximum_iterations=max_target_sequence_length)\n",
    "# 5. Predicting decoder\n",
    "# Shares its parameters with the training decoder (same scope, reuse=True)\n",
    "with tf.variable_scope(\"decode\", reuse=True):\n",
    "    # One <GO> id per row of the fed batch. The row count is read from `inputs` at run time\n",
    "    # rather than from the Python constant `batch_size`, so prediction is not tied to\n",
    "    # batches of exactly `batch_size` sentences.\n",
    "    dynamic_batch_size = tf.shape(inputs)[0]\n",
    "    start_tokens = tf.tile(tf.constant([GO_ID], dtype=tf.int32), [dynamic_batch_size], name='start_tokens')\n",
    "    # Greedy decoding: the arg-max id of each step is embedded and fed back; EOS_ID ends a sequence\n",
    "    predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embeddings,\n",
    "                                                            start_tokens,\n",
    "                                                            EOS_ID)\n",
    "    predicting_decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell,\n",
    "                                                    predicting_helper,\n",
    "                                                    encoder_state,\n",
    "                                                    output_layer)\n",
    "    #  returns: (final_outputs, final_state, final_sequence_lengths)  \n",
    "    predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(predicting_decoder,\n",
    "                                                        impute_finished=True,\n",
    "                                                        maximum_iterations=max_target_sequence_length)\n",
    "      \n",
    "print(training_decoder_output)\n",
    "print(training_decoder_output.rnn_output)  # [batch_size, max_time, target_vocab_size]: per-step scores over the vocabulary\n",
    "print(training_decoder_output.sample_id)   # [batch_size, max_time]: presumably the predicted id of each step"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 损失函数以及优化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"logits:0\", shape=(?, ?, 129939), dtype=float32)\n",
      "Tensor(\"predictions:0\", shape=(32, ?), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "# Loss function\n",
    "training_logits = tf.identity(training_decoder_output.rnn_output, 'logits')  # training needs the logits to compute the loss\n",
    "# NOTE(review): the tensor name 'train_predictins' is misspelled; it is a graph name, so it is left as is.\n",
    "training_preds = tf.identity(training_decoder_output.sample_id, 'train_predictins')\n",
    "# NOTE(review): despite its name, predicting_logits holds sample ids, not logits.\n",
    "predicting_logits = tf.identity(predicting_decoder_output.sample_id, name='predictions')    # prediction only needs the resulting ids\n",
    "# 0.0/1.0 matrix: 1.0 for positions inside each target's length, 0.0 for the padded positions\n",
    "masks = tf.sequence_mask(target_sequence_length, max_target_sequence_length, dtype=tf.float32, name='masks')   \n",
    "with tf.name_scope(\"optimization\"):\n",
    "    # Loss function: sequence loss over the logits, weighted by the mask\n",
    "    training_targets = targets[:, 1:]\n",
    "    cost = tf.contrib.seq2seq.sequence_loss(\n",
    "        training_logits,\n",
    "        training_targets,   # targets with the leading <GO> removed\n",
    "        masks)\n",
    "\n",
    "    # Optimizer\n",
    "    optimizer = tf.train.AdamOptimizer(lr)\n",
    "\n",
    "    # Gradient Clipping: every gradient element is clipped to [-5, 5]; variables without a gradient are skipped\n",
    "    gradients = optimizer.compute_gradients(cost)\n",
    "    capped_gradients = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in gradients if grad is not None]\n",
    "    train_op = optimizer.apply_gradients(capped_gradients, global_step=global_step)\n",
    "\n",
    "    \n",
    "saver = tf.train.Saver(max_to_keep=5)\n",
    "\n",
    "print(training_logits)     # per-step scores over the target vocabulary\n",
    "print(predicting_logits)   # per-step id with the highest score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Counting records in ../data/train_tfrecords/train-14-of-32.tfrecord.\n",
      "INFO:tensorflow:Number of records is at least 1000.\n",
      "INFO:tensorflow:[<tf.Tensor 'random_shuffle_queue_Dequeue:0' shape=(?,) dtype=int64>, <tf.Tensor 'random_shuffle_queue_Dequeue:1' shape=(?,) dtype=int64>, <tf.Tensor 'random_shuffle_queue_Dequeue:2' shape=() dtype=int64>, <tf.Tensor 'random_shuffle_queue_Dequeue:3' shape=() dtype=int64>]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[<Thread(Thread-19, started daemon 140348491589376)>,\n",
       " <Thread(Thread-20, started daemon 140348482934528)>,\n",
       " <Thread(Thread-21, started daemon 140348474541824)>,\n",
       " <Thread(Thread-22, started daemon 140348078192384)>,\n",
       " <Thread(Thread-23, started daemon 140348069799680)>]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the training tfrecord files and start the queue runners that fill the input queue\n",
    "from data_process.sequence_example_helpers import count_records\n",
    "\n",
    "train_dataset_dir = '../data/train_tfrecords/'\n",
    "tfrecord_files = os.listdir(train_dataset_dir)\n",
    "tfrecord_file_names = [os.path.join(train_dataset_dir, file_name) for file_name in tfrecord_files]\n",
    "(from_tokens, to_tokens,\n",
    " from_tokens_len, to_tokens_len) = get_padded_batch(tfrecord_file_names, batch_size)\n",
    "tf.train.start_queue_runners(sess=sess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(32, 42)\n",
      "(32, 43)\n",
      "(32,)\n",
      "(32,)\n"
     ]
    }
   ],
   "source": [
    "# Smoke-test the input pipeline: fetch one batch and print the shape of each array\n",
    "input_tensors = [from_tokens, to_tokens, from_tokens_len, to_tokens_len]\n",
    "(batch_inputs, batch_targets,\n",
    " batch_source_sequence_length, batch_target_sequence_length) = sess.run(input_tensors)\n",
    "print('\\n'.join(str(batch_array.shape) for batch_array in (\n",
    "    batch_inputs, batch_targets, batch_source_sequence_length, batch_target_sequence_length)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3668/10000, cost = 4.93511, passed 2122.07s"
     ]
    }
   ],
   "source": [
    "max_epoch = 5\n",
    "n_tr_batches = 10000  # number of mini-batches counted as one epoch\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "train_fetches = [cost, train_op]\n",
    "\n",
    "for epoch in range(max_epoch):\n",
    "    _costs = 0.0\n",
    "    time0 = time.time()\n",
    "    for batch in range(n_tr_batches):\n",
    "        # Batches come from the input queue, so no index permutation is needed here.\n",
    "        batch_inputs, batch_targets, batch_source_sequence_length, batch_target_sequence_length = sess.run(input_tensors)\n",
    "        # `lr` is deliberately NOT fed: feeding the exponential_decay tensor would replace its\n",
    "        # value with a constant and silently disable the learning-rate schedule.\n",
    "        feed_dict = {inputs: batch_inputs, targets: batch_targets,\n",
    "                     source_sequence_length: batch_source_sequence_length,\n",
    "                     target_sequence_length: batch_target_sequence_length}\n",
    "        _cost, _ = sess.run(train_fetches, feed_dict)  # the cost is the mean cost of one batch\n",
    "        _costs += _cost\n",
    "        sys.stdout.write('\\r%d/%d, cost = %g, passed %gs' % (batch, n_tr_batches, _cost, time.time() - time0))\n",
    "        sys.stdout.flush()\n",
    "    mean_cost = _costs / n_tr_batches\n",
    "    print('\\t train cost = %g, time cost %gs ' % (mean_cost, time.time() - time0))\n",
    "    # One checkpoint per epoch, suffixed with the current global step\n",
    "    save_path = saver.save(sess, '../ckpt/base_seq2seq.ckpt', global_step=sess.run(global_step))\n",
    "    print(save_path)\n",
    "\n",
    "# 19999/20000, cost = 4.46718, passed 11972.7s\t train cost = 4.55279, time cost 11972.7s \n",
    "# ../ckpt/base_seq2seq.ckpt-20000\n",
    "# 1491/20000, cost = 3.56721, passed 882.626s"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### 导入模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Restore the most recent checkpoint found under ckpt_path into the session\n",
    "ckpt_path = '../ckpt/'\n",
    "latest_ckpt = tf.train.latest_checkpoint(ckpt_path)\n",
    "saver.restore(sess, latest_ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def source_to_seq(sentence, max_len=150):\n",
    "    \"\"\"Convert an English sentence into a list of vocabulary ids.\n",
    "\n",
    "    Args:\n",
    "        sentence: English sentence, e.g. 'Nice to meet you.'. Both UTF-8 encoded\n",
    "            byte strings and already-decoded text are accepted.\n",
    "        max_len: maximum number of tokens to keep; longer inputs are truncated.\n",
    "    Returns:\n",
    "        input_ids: ids of the lower-cased tokens; words missing from the\n",
    "            vocabulary map to EN_UNK_ID.\n",
    "        input_seq_len: number of ids in input_ids.\n",
    "    \"\"\"\n",
    "    from nltk import word_tokenize\n",
    "    text = sentence.lower()\n",
    "    # Only byte strings need decoding. Calling .decode() on text that is already decoded\n",
    "    # fails for non-ASCII input on Python 2 (implicit ASCII re-encoding) and does not\n",
    "    # exist at all on Python 3.\n",
    "    if isinstance(text, bytes):\n",
    "        text = text.decode('utf-8')\n",
    "    words = word_tokenize(text.strip())\n",
    "    input_ids = [dict_en_word2id.get(word, EN_UNK_ID) for word in words][:max_len]\n",
    "    return input_ids, len(input_ids)\n",
    "\n",
    "def seq_to_target(ids):\n",
    "    \"\"\"Map target-vocabulary ids back to words and join them with ' / '.\"\"\"\n",
    "    return ' / '.join(dict_zh_id2word[token_id] for token_id in ids)\n",
    "\n",
    "\n",
    "# Quick check of both helpers\n",
    "print(source_to_seq('Dear me, that is a long time not seeing you!'))\n",
    "print(seq_to_target([1,123, 32,443]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def translate(sentence):\n",
    "    \"\"\"Translate one English sentence with the greedy prediction graph.\n",
    "\n",
    "    Args:\n",
    "        sentence: the English sentence to translate.\n",
    "    Returns:\n",
    "        zh_sentence: the predicted target words joined by ' / '.\n",
    "        length: the number of predicted ids.\n",
    "    \"\"\"\n",
    "    input_ids, n_tokens = source_to_seq(sentence)\n",
    "    if n_tokens == 0:\n",
    "        # Nothing to encode; an empty id list cannot be reshaped into a batch.\n",
    "        return '', 0\n",
    "    # Replicate the sentence into `batch_size` identical rows (the feed shape this notebook\n",
    "    # uses for prediction); only the first row is read back.\n",
    "    input_data = np.tile(np.asarray(input_ids, dtype=np.int32), (batch_size, 1))\n",
    "    source_lens = np.full([batch_size], n_tokens, dtype=np.int32)\n",
    "    # Decoding is capped at twice the source length (fed as the target length).\n",
    "    feed_dict = {inputs: input_data,\n",
    "                 source_sequence_length: source_lens,\n",
    "                 target_sequence_length: source_lens * 2}\n",
    "    # predicting_logits holds the predicted ids; row 0 is the answer\n",
    "    predicted_ids = sess.run(predicting_logits, feed_dict=feed_dict)[0]\n",
    "    return seq_to_target(predicted_ids), len(predicted_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Translate two sample sentences and show each result with its length\n",
    "en_sent1 = 'Hi, rabbit. Long time no see! How is it going?'\n",
    "en_sent2 = 'Good morning, every one.'\n",
    "for en_sent in (en_sent1, en_sent2):\n",
    "    zh_sent, sent_len = translate(en_sent)\n",
    "    print('zh_sent len = %d' % sent_len)\n",
    "    print(zh_sent)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
